import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    def __init__(self, yield_stress: float = 0.30):
        """
        Trainable continuous yield threshold for a von Mises-style plasticity correction.

        The threshold is compared directly against the norm of the deviatoric
        logarithmic strain (log of the singular values of F), so it acts as a
        strain-space threshold rather than a stress with physical units.

        Args:
            yield_stress (float): yield threshold for the plastic correction. Values
                below zero are treated as zero in ``forward``.
        """
        super().__init__()
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Project a deformation gradient back onto the von Mises yield surface.

        The principal logarithmic strains are split into a volumetric mean and a
        deviatoric part. If the deviatoric norm exceeds the yield threshold, the
        deviatoric part is scaled down radially onto the threshold. The volumetric
        part (and therefore det F, up to the singular-value clamp) is untouched.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3),
                e.g. (B, 3, 3).

        Returns:
            F_corrected (torch.Tensor): corrected deformation gradient, same shape as F.
        """
        # F = U diag(sigma) Vh
        U, sigma, Vh = torch.linalg.svd(F)  # U: (...,3,3), sigma: (...,3), Vh: (...,3,3)

        # Clamp singular values so degenerate F does not produce log(0).
        epsilon = torch.log(torch.clamp_min(sigma, 1e-6))  # principal log strains (...,3)

        # Volumetric / deviatoric split of the principal log strains.
        epsilon_mean = epsilon.mean(dim=-1, keepdim=True)  # (...,1)
        epsilon_dev = epsilon - epsilon_mean  # (...,3)

        # Clamp the norm so the division below is safe when the deviatoric part vanishes.
        epsilon_dev_norm = torch.linalg.vector_norm(
            epsilon_dev, dim=-1, keepdim=True
        ).clamp_min(1e-12)  # (...,1)

        # The raw parameter is unconstrained and an optimizer may push it below zero.
        # A negative threshold would make delta_gamma exceed the norm. That gives a
        # negative shrink factor, which reflects the deviatoric strain instead of
        # projecting it, so clamp it to a valid (non-negative) threshold here.
        yield_stress = self.yield_stress.clamp_min(0.0)

        # Plastic multiplier: how far the deviatoric norm is outside the yield surface.
        delta_gamma = torch.clamp_min(epsilon_dev_norm - yield_stress, 0.0)  # (...,1)

        # Radial return; with a non-negative threshold this factor lies in [0, 1].
        shrink_factor = 1.0 - delta_gamma / epsilon_dev_norm  # (...,1)

        # Reassemble the log strains and map back to singular values.
        epsilon_corrected = epsilon_mean + epsilon_dev * shrink_factor  # (...,3)
        sigma_corrected = torch.exp(epsilon_corrected)  # (...,3)

        # F_corrected = U diag(sigma_corrected) Vh
        return U @ torch.diag_embed(sigma_corrected) @ Vh


class ElasticityModel(nn.Module):
    def __init__(self, youngs_modulus_log: float = 9.82, poissons_ratio_sigmoid: float = 4.07):
        """
        Trainable continuous parameters for Neo-Hookean elasticity.

        Both parameters are stored in unconstrained form: Young's modulus is
        recovered as ``exp(youngs_modulus_log)`` (always positive). Poisson's ratio
        is recovered as ``0.49 * sigmoid(poissons_ratio_sigmoid)``, which keeps it
        in (0, 0.49).

        Args:
            youngs_modulus_log (float): log of Young's modulus.
            poissons_ratio_sigmoid (float): Poisson's ratio parameter before sigmoid scaling.
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio_sigmoid = nn.Parameter(torch.tensor(poissons_ratio_sigmoid))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute the Kirchhoff stress from a deformation gradient using Neo-Hookean elasticity.

        With first Piola-Kirchhoff stress ``P = mu (F - F^{-T}) + lam log(J) F^{-T}``,
        the Kirchhoff stress ``tau = P F^T`` simplifies (since ``F^{-T} F^T = I``) to
        ``tau = mu (F F^T - I) + lam log(J) I``. That closed form is evaluated here
        directly, so no matrix inverse is needed.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3),
                e.g. (B, 3, 3).

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor, same shape as F.
        """
        # Young's modulus E > 0 and Poisson's ratio nu in (0, 0.49).
        E = self.youngs_modulus_log.exp()  # scalar
        nu = self.poissons_ratio_sigmoid.sigmoid() * 0.49  # scalar

        # Lamé parameters.
        mu = E / (2 * (1 + nu))  # scalar
        lam = E * nu / ((1 + nu) * (1 - 2 * nu))  # scalar

        # (3,3) identity; broadcasts against any leading batch dimensions.
        I = torch.eye(3, dtype=F.dtype, device=F.device)

        # J = det F, clamped so inverted/degenerate elements give a finite log.
        J = torch.linalg.det(F).clamp_min(1e-12)[..., None, None]  # (...,1,1)
        logJ = torch.log(J)  # (...,1,1)

        # tau = mu (F F^T - I) + lam log(J) I
        Ft = F.transpose(-2, -1)  # (...,3,3)
        kirchhoff_stress = mu * (F @ Ft - I) + lam * logJ * I  # (...,3,3)

        return kirchhoff_stress
